{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# TVM Example\n",
    "\n",
    "Reference: [Compiling a simple neural network with TVM](https://tvm.apache.org/docs/get_started/tutorials/quick_start.html)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
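    "Apache TVM is a machine learning compilation framework that follows the principles of Python-first development and universal deployment. It takes pre-trained machine learning models, then compiles them into deployable modules that can be embedded and run anywhere. Apache TVM also lets you customize the optimization process by introducing new optimizations, libraries, code generation, and more.\n",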
    "\n",
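    "Apache TVM can help you:\n",
    "\n",
    "- **Optimize**: improve the performance of ML workloads by composing libraries and code generation.\n",
    "- **Deploy**: bring ML workloads to a diverse set of new environments, including new runtimes and new hardware.\n",
    "- **Continuously improve and customize**: adapt the ML deployment pipeline in Python by quickly customizing library dispatching and bringing in custom operators and code generation."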
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
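    "The overall flow consists of the following steps:\n",
    "\n",
    "- **Construct or import a model**: Construct a neural network model, or import a pre-trained model from another framework (for example PyTorch or ONNX), and create a TVM IRModule. The IRModule contains all the information needed for compilation, including high-level Relax functions for the computational graph and low-level TensorIR functions for tensor programs.\n",
    "- **Perform composable optimizations**: Apply a series of optimization transformations, such as graph optimizations, tensor program optimizations, and library dispatching.\n",
    "- **Build and universal deployment**: Build the optimized model into a module deployable to the universal runtime, and execute it on different devices such as CPUs, GPUs, or other accelerators."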
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
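    "## Constructing or Importing a Model\n",
    "\n",
    "Use the TVM Relax frontend to define a two-layer multilayer perceptron (MLP) directly. The frontend API is similar to PyTorch's."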
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tvm\n",
    "from tvm import relax\n",
    "from tvm.relax.frontend import nn\n",
    "\n",
    "class MLPModel(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.fc1 = nn.Linear(784, 256)\n",
    "        self.relu1 = nn.ReLU()\n",
    "        self.fc2 = nn.Linear(256, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.fc1(x)\n",
    "        x = self.relu1(x)\n",
    "        x = self.fc2(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The model can then be exported to a TVM IRModule, the central intermediate representation in TVM."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "tags": [
     "hide-output"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
       "\n",
       "<span style=\"color: #A2F\">@I</span><span style=\"color: #A2F; font-weight: bold\">.</span>ir_module\n",
       "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #00F; font-weight: bold\">Module</span>:\n",
       "    <span style=\"color: #A2F\">@R</span><span style=\"color: #A2F; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #00F\">forward</span>(x: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">784</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), fc1_weight: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">256</span>, <span style=\"color: #008000\">784</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), fc1_bias: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">256</span>,), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), fc2_weight: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">256</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), fc2_bias: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">10</span>,), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #A2F; font-weight: bold\">-&gt;</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "        R<span style=\"color: #A2F; font-weight: bold\">.</span>func_attr({<span style=\"color: #BA2121\">&quot;num_input&quot;</span>: <span style=\"color: #008000\">1</span>})\n",
       "        <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>dataflow():\n",
       "            permute_dims: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">784</span>, <span style=\"color: #008000\">256</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #A2F; font-weight: bold\">=</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>permute_dims(fc1_weight, axes<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #008000; font-weight: bold\">None</span>)\n",
       "            matmul: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">256</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #A2F; font-weight: bold\">=</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>matmul(x, permute_dims, out_dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;void&quot;</span>)\n",
       "            add: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">256</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #A2F; font-weight: bold\">=</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>add(matmul, fc1_bias)\n",
       "            relu: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">256</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #A2F; font-weight: bold\">=</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>nn<span style=\"color: #A2F; font-weight: bold\">.</span>relu(add)\n",
       "            permute_dims1: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">256</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #A2F; font-weight: bold\">=</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>permute_dims(fc2_weight, axes<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #008000; font-weight: bold\">None</span>)\n",
       "            matmul1: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #A2F; font-weight: bold\">=</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>matmul(relu, permute_dims1, out_dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;void&quot;</span>)\n",
       "            add1: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #A2F; font-weight: bold\">=</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>add(matmul1, fc2_bias)\n",
       "            gv: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #A2F; font-weight: bold\">=</span> add1\n",
       "            R<span style=\"color: #A2F; font-weight: bold\">.</span>output(gv)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> gv\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "model = MLPModel()\n",
    "mod, param_spec = model.export_tvm(\n",
    "    spec={\"forward\": {\"x\": nn.spec.Tensor((1, 784), \"float32\")}}\n",
    ")\n",
    "mod.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
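    "## Performing Optimization Transformations\n",
    "\n",
    "Apache TVM uses a `pipeline` to transform and optimize programs. The pipeline encapsulates a collection of transformations that pursue two goals (at the same level):\n",
    "\n",
    "1. Model optimization: for example operator fusion and layout rewrites.\n",
    "2. Tensor program optimization: mapping operators to low-level implementations (either libraries or code generation).\n",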
    "\n",
    "```{note}\n",
    "These two items are goals, not stages of the `pipeline`. The two kinds of optimization may be performed at the same level, or separately in two stages.\n",
    "```\n",
    "\n",
    "This tutorial only demonstrates the overall flow, so it uses the `zero` optimization pipeline rather than optimizing for any specific target."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "mod = relax.get_pipeline(\"zero\")(mod)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "tags": [
     "hide-output"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import tir as T</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
       "\n",
       "<span style=\"color: #A2F\">@I</span><span style=\"color: #A2F; font-weight: bold\">.</span>ir_module\n",
       "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #00F; font-weight: bold\">Module</span>:\n",
       "    <span style=\"color: #A2F\">@T</span><span style=\"color: #A2F; font-weight: bold\">.</span>prim_func(private<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #008000; font-weight: bold\">True</span>)\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #00F\">fused_matmul1_add1</span>(relu: T<span style=\"color: #A2F; font-weight: bold\">.</span>Buffer((T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">1</span>), T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">256</span>)), <span style=\"color: #BA2121\">&quot;float32&quot;</span>), permute_dims1: T<span style=\"color: #A2F; font-weight: bold\">.</span>Buffer((T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">256</span>), T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">10</span>)), <span style=\"color: #BA2121\">&quot;float32&quot;</span>), fc2_bias: T<span style=\"color: #A2F; font-weight: bold\">.</span>Buffer((T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">10</span>),), <span style=\"color: #BA2121\">&quot;float32&quot;</span>), T_add_intermediate: T<span style=\"color: #A2F; font-weight: bold\">.</span>Buffer((T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">1</span>), T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">10</span>)), <span style=\"color: #BA2121\">&quot;float32&quot;</span>)):\n",
       "        T<span style=\"color: #A2F; font-weight: bold\">.</span>func_attr({<span style=\"color: #BA2121\">&quot;tir.noalias&quot;</span>: <span style=\"color: #008000; font-weight: bold\">True</span>})\n",
       "        <span style=\"color: #007979; font-style: italic\"># with T.block(&quot;root&quot;):</span>\n",
       "        matmul_intermediate <span style=\"color: #A2F; font-weight: bold\">=</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>alloc_buffer((T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">1</span>), T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">10</span>)))\n",
       "        <span style=\"color: #008000; font-weight: bold\">for</span> i0, i1, k <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>grid(T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">1</span>), T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">10</span>), T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">256</span>)):\n",
       "            <span style=\"color: #008000; font-weight: bold\">with</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>block(<span style=\"color: #BA2121\">&quot;matmul&quot;</span>):\n",
       "                v_i0, v_i1, v_k <span style=\"color: #A2F; font-weight: bold\">=</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>axis<span style=\"color: #A2F; font-weight: bold\">.</span>remap(<span style=\"color: #BA2121\">&quot;SSR&quot;</span>, [i0, i1, k])\n",
       "                T<span style=\"color: #A2F; font-weight: bold\">.</span>reads(relu[v_i0, v_k], permute_dims1[v_k, v_i1])\n",
       "                T<span style=\"color: #A2F; font-weight: bold\">.</span>writes(matmul_intermediate[v_i0, v_i1])\n",
       "                <span style=\"color: #008000; font-weight: bold\">with</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>init():\n",
       "                    matmul_intermediate[v_i0, v_i1] <span style=\"color: #A2F; font-weight: bold\">=</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>float32(<span style=\"color: #008000\">0.0</span>)\n",
       "                matmul_intermediate[v_i0, v_i1] <span style=\"color: #A2F; font-weight: bold\">=</span> matmul_intermediate[v_i0, v_i1] <span style=\"color: #A2F; font-weight: bold\">+</span> relu[v_i0, v_k] <span style=\"color: #A2F; font-weight: bold\">*</span> permute_dims1[v_k, v_i1]\n",
       "        <span style=\"color: #008000; font-weight: bold\">for</span> ax0, ax1 <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>grid(T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">1</span>), T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">10</span>)):\n",
       "            <span style=\"color: #008000; font-weight: bold\">with</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>block(<span style=\"color: #BA2121\">&quot;T_add&quot;</span>):\n",
       "                v_ax0, v_ax1 <span style=\"color: #A2F; font-weight: bold\">=</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>axis<span style=\"color: #A2F; font-weight: bold\">.</span>remap(<span style=\"color: #BA2121\">&quot;SS&quot;</span>, [ax0, ax1])\n",
       "                T<span style=\"color: #A2F; font-weight: bold\">.</span>reads(matmul_intermediate[v_ax0, v_ax1], fc2_bias[v_ax1])\n",
       "                T<span style=\"color: #A2F; font-weight: bold\">.</span>writes(T_add_intermediate[v_ax0, v_ax1])\n",
       "                T_add_intermediate[v_ax0, v_ax1] <span style=\"color: #A2F; font-weight: bold\">=</span> matmul_intermediate[v_ax0, v_ax1] <span style=\"color: #A2F; font-weight: bold\">+</span> fc2_bias[v_ax1]\n",
       "\n",
       "    <span style=\"color: #A2F\">@T</span><span style=\"color: #A2F; font-weight: bold\">.</span>prim_func(private<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #008000; font-weight: bold\">True</span>)\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #00F\">fused_matmul_add_relu</span>(x: T<span style=\"color: #A2F; font-weight: bold\">.</span>Buffer((T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">1</span>), T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">784</span>)), <span style=\"color: #BA2121\">&quot;float32&quot;</span>), permute_dims: T<span style=\"color: #A2F; font-weight: bold\">.</span>Buffer((T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">784</span>), T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">256</span>)), <span style=\"color: #BA2121\">&quot;float32&quot;</span>), fc1_bias: T<span style=\"color: #A2F; font-weight: bold\">.</span>Buffer((T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">256</span>),), <span style=\"color: #BA2121\">&quot;float32&quot;</span>), compute_intermediate: T<span style=\"color: #A2F; font-weight: bold\">.</span>Buffer((T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">1</span>), T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">256</span>)), <span style=\"color: #BA2121\">&quot;float32&quot;</span>)):\n",
       "        T<span style=\"color: #A2F; font-weight: bold\">.</span>func_attr({<span style=\"color: #BA2121\">&quot;tir.noalias&quot;</span>: <span style=\"color: #008000; font-weight: bold\">True</span>})\n",
       "        <span style=\"color: #007979; font-style: italic\"># with T.block(&quot;root&quot;):</span>\n",
       "        matmul_intermediate <span style=\"color: #A2F; font-weight: bold\">=</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>alloc_buffer((T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">1</span>), T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">256</span>)))\n",
       "        T_add_intermediate <span style=\"color: #A2F; font-weight: bold\">=</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>alloc_buffer((T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">1</span>), T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">256</span>)))\n",
       "        <span style=\"color: #008000; font-weight: bold\">for</span> i0, i1, k <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>grid(T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">1</span>), T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">256</span>), T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">784</span>)):\n",
       "            <span style=\"color: #008000; font-weight: bold\">with</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>block(<span style=\"color: #BA2121\">&quot;matmul&quot;</span>):\n",
       "                v_i0, v_i1, v_k <span style=\"color: #A2F; font-weight: bold\">=</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>axis<span style=\"color: #A2F; font-weight: bold\">.</span>remap(<span style=\"color: #BA2121\">&quot;SSR&quot;</span>, [i0, i1, k])\n",
       "                T<span style=\"color: #A2F; font-weight: bold\">.</span>reads(x[v_i0, v_k], permute_dims[v_k, v_i1])\n",
       "                T<span style=\"color: #A2F; font-weight: bold\">.</span>writes(matmul_intermediate[v_i0, v_i1])\n",
       "                <span style=\"color: #008000; font-weight: bold\">with</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>init():\n",
       "                    matmul_intermediate[v_i0, v_i1] <span style=\"color: #A2F; font-weight: bold\">=</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>float32(<span style=\"color: #008000\">0.0</span>)\n",
       "                matmul_intermediate[v_i0, v_i1] <span style=\"color: #A2F; font-weight: bold\">=</span> matmul_intermediate[v_i0, v_i1] <span style=\"color: #A2F; font-weight: bold\">+</span> x[v_i0, v_k] <span style=\"color: #A2F; font-weight: bold\">*</span> permute_dims[v_k, v_i1]\n",
       "        <span style=\"color: #008000; font-weight: bold\">for</span> ax0, ax1 <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>grid(T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">1</span>), T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">256</span>)):\n",
       "            <span style=\"color: #008000; font-weight: bold\">with</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>block(<span style=\"color: #BA2121\">&quot;T_add&quot;</span>):\n",
       "                v_ax0, v_ax1 <span style=\"color: #A2F; font-weight: bold\">=</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>axis<span style=\"color: #A2F; font-weight: bold\">.</span>remap(<span style=\"color: #BA2121\">&quot;SS&quot;</span>, [ax0, ax1])\n",
       "                T<span style=\"color: #A2F; font-weight: bold\">.</span>reads(matmul_intermediate[v_ax0, v_ax1], fc1_bias[v_ax1])\n",
       "                T<span style=\"color: #A2F; font-weight: bold\">.</span>writes(T_add_intermediate[v_ax0, v_ax1])\n",
       "                T_add_intermediate[v_ax0, v_ax1] <span style=\"color: #A2F; font-weight: bold\">=</span> matmul_intermediate[v_ax0, v_ax1] <span style=\"color: #A2F; font-weight: bold\">+</span> fc1_bias[v_ax1]\n",
       "        <span style=\"color: #008000; font-weight: bold\">for</span> i0, i1 <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>grid(T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">1</span>), T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">256</span>)):\n",
       "            <span style=\"color: #008000; font-weight: bold\">with</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>block(<span style=\"color: #BA2121\">&quot;compute&quot;</span>):\n",
       "                v_i0, v_i1 <span style=\"color: #A2F; font-weight: bold\">=</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>axis<span style=\"color: #A2F; font-weight: bold\">.</span>remap(<span style=\"color: #BA2121\">&quot;SS&quot;</span>, [i0, i1])\n",
       "                T<span style=\"color: #A2F; font-weight: bold\">.</span>reads(T_add_intermediate[v_i0, v_i1])\n",
       "                T<span style=\"color: #A2F; font-weight: bold\">.</span>writes(compute_intermediate[v_i0, v_i1])\n",
       "                compute_intermediate[v_i0, v_i1] <span style=\"color: #A2F; font-weight: bold\">=</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>max(T_add_intermediate[v_i0, v_i1], T<span style=\"color: #A2F; font-weight: bold\">.</span>float32(<span style=\"color: #008000\">0.0</span>))\n",
       "\n",
       "    <span style=\"color: #A2F\">@T</span><span style=\"color: #A2F; font-weight: bold\">.</span>prim_func(private<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #008000; font-weight: bold\">True</span>)\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #00F\">transpose</span>(fc1_weight: T<span style=\"color: #A2F; font-weight: bold\">.</span>Buffer((T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">256</span>), T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">784</span>)), <span style=\"color: #BA2121\">&quot;float32&quot;</span>), T_transpose: T<span style=\"color: #A2F; font-weight: bold\">.</span>Buffer((T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">784</span>), T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">256</span>)), <span style=\"color: #BA2121\">&quot;float32&quot;</span>)):\n",
       "        T<span style=\"color: #A2F; font-weight: bold\">.</span>func_attr({<span style=\"color: #BA2121\">&quot;op_pattern&quot;</span>: <span style=\"color: #008000\">2</span>, <span style=\"color: #BA2121\">&quot;tir.noalias&quot;</span>: <span style=\"color: #008000; font-weight: bold\">True</span>})\n",
       "        <span style=\"color: #007979; font-style: italic\"># with T.block(&quot;root&quot;):</span>\n",
       "        <span style=\"color: #008000; font-weight: bold\">for</span> ax0, ax1 <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>grid(T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">784</span>), T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">256</span>)):\n",
       "            <span style=\"color: #008000; font-weight: bold\">with</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>block(<span style=\"color: #BA2121\">&quot;T_transpose&quot;</span>):\n",
       "                v_ax0, v_ax1 <span style=\"color: #A2F; font-weight: bold\">=</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>axis<span style=\"color: #A2F; font-weight: bold\">.</span>remap(<span style=\"color: #BA2121\">&quot;SS&quot;</span>, [ax0, ax1])\n",
       "                T<span style=\"color: #A2F; font-weight: bold\">.</span>reads(fc1_weight[v_ax1, v_ax0])\n",
       "                T<span style=\"color: #A2F; font-weight: bold\">.</span>writes(T_transpose[v_ax0, v_ax1])\n",
       "                T_transpose[v_ax0, v_ax1] <span style=\"color: #A2F; font-weight: bold\">=</span> fc1_weight[v_ax1, v_ax0]\n",
       "\n",
       "    <span style=\"color: #A2F\">@T</span><span style=\"color: #A2F; font-weight: bold\">.</span>prim_func(private<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #008000; font-weight: bold\">True</span>)\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #00F\">transpose1</span>(fc2_weight: T<span style=\"color: #A2F; font-weight: bold\">.</span>Buffer((T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">10</span>), T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">256</span>)), <span style=\"color: #BA2121\">&quot;float32&quot;</span>), T_transpose: T<span style=\"color: #A2F; font-weight: bold\">.</span>Buffer((T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">256</span>), T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">10</span>)), <span style=\"color: #BA2121\">&quot;float32&quot;</span>)):\n",
       "        T<span style=\"color: #A2F; font-weight: bold\">.</span>func_attr({<span style=\"color: #BA2121\">&quot;op_pattern&quot;</span>: <span style=\"color: #008000\">2</span>, <span style=\"color: #BA2121\">&quot;tir.noalias&quot;</span>: <span style=\"color: #008000; font-weight: bold\">True</span>})\n",
       "        <span style=\"color: #007979; font-style: italic\"># with T.block(&quot;root&quot;):</span>\n",
       "        <span style=\"color: #008000; font-weight: bold\">for</span> ax0, ax1 <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>grid(T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">256</span>), T<span style=\"color: #A2F; font-weight: bold\">.</span>int64(<span style=\"color: #008000\">10</span>)):\n",
       "            <span style=\"color: #008000; font-weight: bold\">with</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>block(<span style=\"color: #BA2121\">&quot;T_transpose&quot;</span>):\n",
       "                v_ax0, v_ax1 <span style=\"color: #A2F; font-weight: bold\">=</span> T<span style=\"color: #A2F; font-weight: bold\">.</span>axis<span style=\"color: #A2F; font-weight: bold\">.</span>remap(<span style=\"color: #BA2121\">&quot;SS&quot;</span>, [ax0, ax1])\n",
       "                T<span style=\"color: #A2F; font-weight: bold\">.</span>reads(fc2_weight[v_ax1, v_ax0])\n",
       "                T<span style=\"color: #A2F; font-weight: bold\">.</span>writes(T_transpose[v_ax0, v_ax1])\n",
       "                T_transpose[v_ax0, v_ax1] <span style=\"color: #A2F; font-weight: bold\">=</span> fc2_weight[v_ax1, v_ax0]\n",
       "\n",
       "    <span style=\"color: #A2F\">@R</span><span style=\"color: #A2F; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #00F\">forward</span>(x: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">784</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), fc1_weight: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">256</span>, <span style=\"color: #008000\">784</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), fc1_bias: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">256</span>,), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), fc2_weight: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">256</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>), fc2_bias: R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">10</span>,), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #A2F; font-weight: bold\">-&gt;</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "        R<span style=\"color: #A2F; font-weight: bold\">.</span>func_attr({<span style=\"color: #BA2121\">&quot;num_input&quot;</span>: <span style=\"color: #008000\">1</span>})\n",
       "        cls <span style=\"color: #A2F; font-weight: bold\">=</span> Module\n",
       "        <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>dataflow():\n",
       "            permute_dims <span style=\"color: #A2F; font-weight: bold\">=</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>call_tir(cls<span style=\"color: #A2F; font-weight: bold\">.</span>transpose, (fc1_weight,), out_sinfo<span style=\"color: #A2F; font-weight: bold\">=</span>R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">784</span>, <span style=\"color: #008000\">256</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>))\n",
       "            lv <span style=\"color: #A2F; font-weight: bold\">=</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>call_tir(cls<span style=\"color: #A2F; font-weight: bold\">.</span>fused_matmul_add_relu, (x, permute_dims, fc1_bias), out_sinfo<span style=\"color: #A2F; font-weight: bold\">=</span>R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">256</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>))\n",
       "            permute_dims1 <span style=\"color: #A2F; font-weight: bold\">=</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>call_tir(cls<span style=\"color: #A2F; font-weight: bold\">.</span>transpose1, (fc2_weight,), out_sinfo<span style=\"color: #A2F; font-weight: bold\">=</span>R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">256</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>))\n",
       "            gv <span style=\"color: #A2F; font-weight: bold\">=</span> R<span style=\"color: #A2F; font-weight: bold\">.</span>call_tir(cls<span style=\"color: #A2F; font-weight: bold\">.</span>fused_matmul1_add1, (lv, permute_dims1, fc2_bias), out_sinfo<span style=\"color: #A2F; font-weight: bold\">=</span>R<span style=\"color: #A2F; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">10</span>), dtype<span style=\"color: #A2F; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>))\n",
       "            R<span style=\"color: #A2F; font-weight: bold\">.</span>output(gv)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> gv\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "mod.show()"
   ]
  },
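  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the optimized `forward` printed above easier to read, here is a minimal pure-Python sketch of the same dataflow on tiny shapes. The helper names `transpose`, `matmul_add`, and `relu` are illustrative assumptions, not TVM APIs, and the behaviour of each fused kernel is inferred from its name and from the original `Linear -> ReLU -> Linear` model.\n",
    "\n",
    "```python\n",
    "def transpose(m):\n",
    "    # Swap rows and columns of a 2-D nested list.\n",
    "    return [list(col) for col in zip(*m)]\n",
    "\n",
    "def matmul_add(x, w_t, b):\n",
    "    # x: (1, in), w_t: (in, out), b: (out,) -> (1, out)\n",
    "    return [[sum(xi * w_t[i][j] for i, xi in enumerate(row)) + b[j]\n",
    "             for j in range(len(b))] for row in x]\n",
    "\n",
    "def relu(m):\n",
    "    return [[max(v, 0.0) for v in row] for row in m]\n",
    "\n",
    "def forward(x, fc1_weight, fc1_bias, fc2_weight, fc2_bias):\n",
    "    permute_dims = transpose(fc1_weight)               # mirrors cls.transpose\n",
    "    lv = relu(matmul_add(x, permute_dims, fc1_bias))   # mirrors cls.fused_matmul_add_relu\n",
    "    permute_dims1 = transpose(fc2_weight)              # mirrors cls.transpose1\n",
    "    gv = matmul_add(lv, permute_dims1, fc2_bias)       # mirrors cls.fused_matmul1_add1\n",
    "    return gv\n",
    "\n",
    "# Tiny shapes (2 -> 3 -> 1) stand in for 784 -> 256 -> 10.\n",
    "x = [[1.0, 2.0]]\n",
    "fc1_weight = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]  # (3, 2), i.e. (out, in)\n",
    "fc1_bias = [0.0, 0.0, 0.0]\n",
    "fc2_weight = [[1.0, 1.0, 1.0]]                        # (1, 3)\n",
    "fc2_bias = [0.5]\n",
    "result = forward(x, fc1_weight, fc1_bias, fc2_weight, fc2_bias)\n",
    "print(result)\n",
    "```\n",
    "\n",
    "Each weight is stored as `(out_features, in_features)`, which is why a transpose precedes every matmul in the optimized module."
   ]
  },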
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Build and Universal Deployment\n",
    "\n",
    "Once optimization is complete, build the model into a deployable module and run it on different devices."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[24260.729 23490.016 25546.697 25084.074 24464.416 23487.184 23823.516\n",
      "  23294.605 24136.322 23890.729]]\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Compile the optimized IRModule for the host CPU through LLVM.\n",
    "target = tvm.target.Target(\"llvm\")\n",
    "ex = tvm.compile(mod, target)\n",
    "device = tvm.cpu()\n",
    "vm = tvm.runtime.vm.VirtualMachine(ex, device)\n",
    "\n",
    "# Random input and random weights: the printed values are not meaningful predictions.\n",
    "data = np.random.rand(1, 784).astype(\"float32\")\n",
    "tvm_data = tvm.runtime.tensor(data, device=device)\n",
    "params = [np.random.rand(*param.shape).astype(\"float32\") for _, param in param_spec]\n",
    "params = [tvm.runtime.tensor(param, device=device) for param in params]\n",
    "\n",
    "# Parameters follow the input positionally, in param_spec order.\n",
    "print(vm[\"forward\"](tvm_data, *params).numpy())"
   ]
  },
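  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The logits printed above are around 24,000 because both the input and every parameter are drawn from `np.random.rand`, that is, uniformly from [0, 1). A rough expected-value estimate, assuming all draws are independent and using only plain Python (the variable names are illustrative), lands in the same order of magnitude:\n",
    "\n",
    "```python\n",
    "in_features, hidden_features = 784, 256\n",
    "mean_u = 0.5  # mean of a Uniform[0, 1) draw\n",
    "\n",
    "# Layer 1: sum of 784 products x_i * w_ij (mean 0.25 each) plus a bias (mean 0.5).\n",
    "# Every term is non-negative, so ReLU leaves the value unchanged.\n",
    "hidden_mean = in_features * mean_u * mean_u + mean_u\n",
    "\n",
    "# Layer 2: sum of 256 products h_j * w_jk plus a bias.\n",
    "logit_mean = hidden_features * hidden_mean * mean_u + mean_u\n",
    "\n",
    "print(hidden_mean, logit_mean)\n",
    "```\n",
    "\n",
    "This is only a sanity check on magnitude: individual logits vary around the estimate because each one is a single random draw."
   ]
  },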
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
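    "TVM aims to bring machine learning to applications written in any language of interest, with minimal runtime support.\n",
    "\n",
    "- Each function in the IRModule becomes a runnable function in the runtime. For example, in the LLM case, the `prefill` and `decode` functions can be called directly.\n",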
    "    ```python\n",
    "    prefill_logits = vm[\"prefill\"](inputs, weight, kv_cache)\n",
    "    decoded_logits = vm[\"decode\"](inputs, weight, kv_cache)\n",
    "    ```\n",
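    "- The TVM runtime ships with native data structures such as NDArray, which also support zero-copy exchange with the existing ecosystem (for example, with PyTorch via DLPack).\n",
    "    ```python\n",
    "    # Convert a PyTorch tensor to a TVM NDArray\n",
    "    x_tvm = tvm.nd.from_dlpack(x_torch.to_dlpack())\n",
    "    # Convert a TVM NDArray to a PyTorch tensor\n",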
    "    x_torch = torch.from_dlpack(x_tvm.to_dlpack())\n",
    "    ```\n",
    "- The TVM runtime works in non-Python environments, so it can run in settings such as mobile devices.\n",
    "    ```cpp\n",
    "    // C++ snippet\n",
    "    runtime::Module vm = ex.GetFunction(\"load_executable\")();\n",
    "    vm.GetFunction(\"init\")(...);\n",
    "    NDArray out = vm.GetFunction(\"prefill\")(data, weight, kv_cache);\n",
    "    ```\n",
    "    \n",
    "    ```java\n",
    "    // Java snippet\n",
    "    Module vm = ex.getFunction(\"load_executable\").invoke();\n",
    "    vm.getFunction(\"init\").pushArg(...).invoke();\n",
    "    NDArray out = vm.getFunction(\"prefill\").pushArg(data).pushArg(weight).pushArg(kv_cache).invoke();\n",
    "    ```\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "aix",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
